-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify fluid api recognize digit #10308
Conversation
BATCH_SIZE = 64 | ||
|
||
|
||
def inference_network(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you rename inference_network
to inference_program
and train_network
to train_program
? Yi pointed out that "program" is more suitable for Fluid than "network".
acc_val = numpy.array(acc_set).mean() | ||
avg_loss_val = numpy.array(avg_loss_set).mean() | ||
if float(acc_val) > 0.2: # Smaller value to increase CI speed | ||
trainer.params.save(save_dirname) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you change trainer.params.save
to trainer.save_params
? (https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py#L103)
|
||
def event_handler(event): | ||
if isinstance(event, fluid.EndIteration): | ||
print(event.metrics) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you change print(event.metrics)
to avg_cost, acc = event.values
? (e.g., same name as the return value of train_network
), and rename avg_loss_val
and acc_val
below to avg_cost
and acc
for clarity.
params = fluid.Params(save_dirname) | ||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
|
||
inferencer = fluid.Inferencer(inference_network, params, place=place) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry that the API has changed slightly, please change Inferencer
and Trainer
's constructor according to https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py#L118
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No Problem, will do
8f1c70b
to
e68cc6e
Compare
@helinwang Updated according to review. |
avg_cost = numpy.array(avg_cost_set).mean() | ||
if float(acc) > 0.2: # Smaller value to increase CI speed | ||
trainer.save_params(save_dirname) | ||
return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This return
actually does not stop training, perhaps remove it.
Not related to this PR, maybe we need a function called trainer.stop_train
.
return | ||
else: | ||
print( | ||
'BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.format( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are 1:04
, 2:2.2
, 3:2.2
typos?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! They are from the original code. When I removed the first item in the print, I forgot to change the others. Will fix this.
|
||
else: | ||
print( | ||
'BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.format( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are 1:04
, 2:2.2
, 3:2.2
typos?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
646e65a
to
c06e1ef
Compare
Update several variable names to make them more consistant. Update the Inferencer construction call.
c06e1ef
to
c10d030
Compare
No description provided.